{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.653809Z",
     "start_time": "2025-02-20T11:28:27.646808Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "    \n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.1\n",
      "torch 2.6.0+cu118\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.713565Z",
     "start_time": "2025-02-20T11:28:27.654852Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "# 定义数据集的变换\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor()\n",
    "])\n",
    "# fashion_mnist图像分类数据集，衣服分类，60000张训练图片，10000张测试图片\n",
    "train_ds = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform\n",
    ")\n",
    "\n",
    "test_ds = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transform\n",
    ")\n",
    "\n",
    "# torchvision 数据集里没有提供训练集和验证集的划分\n",
    "# 当然也可以用 torch.utils.data.Dataset 实现人为划分"
   ],
   "id": "49248fca4413a3a1",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.717339Z",
     "start_time": "2025-02-20T11:28:27.714608Z"
    }
   },
   "cell_type": "code",
   "source": "",
   "id": "e205da284596a539",
   "outputs": [],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.723617Z",
     "start_time": "2025-02-20T11:28:27.719260Z"
    }
   },
   "cell_type": "code",
   "source": "img, label = train_ds[0]",
   "id": "4d22de1235004660",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.734145Z",
     "start_time": "2025-02-20T11:28:27.724635Z"
    }
   },
   "cell_type": "code",
   "source": "img",
   "id": "17cfe2fd5bfaa850",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,\n",
       "          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0039, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,\n",
       "          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,\n",
       "          0.0157, 0.0000, 0.0000, 0.0118],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,\n",
       "          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0471, 0.0392, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,\n",
       "          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,\n",
       "          0.3020, 0.5098, 0.2824, 0.0588],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,\n",
       "          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,\n",
       "          0.5529, 0.3451, 0.6745, 0.2588],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,\n",
       "          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,\n",
       "          0.4824, 0.7686, 0.8980, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,\n",
       "          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,\n",
       "          0.8745, 0.9608, 0.6784, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,\n",
       "          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,\n",
       "          0.8627, 0.9529, 0.7922, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,\n",
       "          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,\n",
       "          0.8863, 0.7725, 0.8196, 0.2039],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,\n",
       "          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,\n",
       "          0.9608, 0.4667, 0.6549, 0.2196],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,\n",
       "          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,\n",
       "          0.8510, 0.8196, 0.3608, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,\n",
       "          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,\n",
       "          0.8549, 1.0000, 0.3020, 0.0000],\n",
       "         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,\n",
       "          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,\n",
       "          0.8784, 0.9569, 0.6235, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,\n",
       "          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,\n",
       "          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,\n",
       "          0.9137, 0.9333, 0.8431, 0.0000],\n",
       "         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,\n",
       "          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,\n",
       "          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,\n",
       "          0.8627, 0.9098, 0.9647, 0.0000],\n",
       "         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,\n",
       "          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,\n",
       "          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,\n",
       "          0.8706, 0.8941, 0.8824, 0.0000],\n",
       "         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,\n",
       "          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,\n",
       "          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,\n",
       "          0.8745, 0.8784, 0.8980, 0.1137],\n",
       "         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,\n",
       "          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,\n",
       "          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,\n",
       "          0.8627, 0.8667, 0.9020, 0.2627],\n",
       "         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,\n",
       "          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,\n",
       "          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,\n",
       "          0.7098, 0.8039, 0.8078, 0.4510],\n",
       "         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,\n",
       "          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,\n",
       "          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,\n",
       "          0.6549, 0.6941, 0.8235, 0.3608],\n",
       "         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,\n",
       "          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,\n",
       "          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,\n",
       "          0.7529, 0.8471, 0.6667, 0.0000],\n",
       "         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,\n",
       "          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,\n",
       "          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,\n",
       "          0.3882, 0.2275, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,\n",
       "          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "          0.0000, 0.0000, 0.0000, 0.0000]]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:27.740336Z",
     "start_time": "2025-02-20T11:28:27.734666Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 显示图片，这里需要把transforms.ToTensor(),进行归一化注释掉，否则是不行的\n",
    "def show_img_content(img):\n",
    "    from PIL import Image\n",
    "\n",
    "    # 打开一个图像文件\n",
    "    # img = Image.open(img)\n",
    "\n",
    "\n",
    "    print(\"图像大小:\", img.size)\n",
    "    print(\"图像模式:\", img.mode)\n",
    "\n",
    "\n",
    "    # 如果图像是单通道的，比如灰度图，你可以这样获取像素值列表：\n",
    "    if img.mode == 'L':\n",
    "        pixel_values = list(img.getdata())\n",
    "        print(pixel_values)\n",
    "show_img_content(img) #这里必须把上面的 transforms.ToTensor(), # 转换为tensor，进行归一化注释掉，否则是不行的"
   ],
   "id": "f568d50e8bbbcace",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图像大小: <built-in method size of Tensor object at 0x000001CBAA89E940>\n",
      "图像模式: <built-in method mode of Tensor object at 0x000001CBAA89E940>\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T11:28:59.932045Z",
     "start_time": "2025-02-20T11:28:51.689014Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#计算均值和方差\n",
    "def cal_mean_std(ds):\n",
    "    mean = 0.\n",
    "    std = 0.\n",
    "    for img, _ in ds:  # 遍历每张图片,img.shape=[1,28,28]\n",
    "        mean += img.mean(dim=(1, 2))  # 计算每张图片的均值，dim=(1, 2)表示计算每张图片的每一个像素的均值,行列共同求均值\n",
    "        std += img.std(dim=(1, 2))\n",
    "    mean /= len(ds)\n",
    "    std /= len(ds)\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "print(cal_mean_std(train_ds))\n"
   ],
   "id": "c2b5954cb0b450ae",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([0.2860]), tensor([0.3205]))\n"
     ]
    }
   ],
   "execution_count": 13
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
